{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/sharedata/mdy/miniforge/envs/cuda128/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import triton\n",
    "import triton.language as tl\n",
    "from copy import deepcopy\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "from transformers import Qwen2ForCausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
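    "# This code is adapted from the trl repository, so the Triton implementation below follows trl's formulation as well.\n",
    "# The key assumption is that p(x) and p_old(x) are identical.\n",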
    "def get_log_probs(logits, input_ids):\n",
    "    '''Gather the log-probability of every target token, one batch row at a time.\n",
    "\n",
    "    Only the trailing logits.size(1) columns of `input_ids` are used as targets.\n",
    "    Each row is log-softmaxed over its last dimension on its own and the per-row\n",
    "    results are stacked along a new leading dimension.\n",
    "    '''\n",
    "    target_ids = input_ids[:, -logits.size(1):]\n",
    "    gathered_rows = [\n",
    "        torch.gather(\n",
    "            row_logits.log_softmax(dim=-1), dim=1, index=row_ids.unsqueeze(1)\n",
    "        ).squeeze(1)\n",
    "        for row_logits, row_ids in zip(logits, target_ids)\n",
    "    ]\n",
    "    return torch.stack(gathered_rows)\n",
    "\n",
    "def torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):\n",
    "    '''Reference PyTorch implementation of the per-token GRPO loss.\n",
    "\n",
    "    `logits` is expected to be produced like this:\n",
    "        logits_to_keep = completion_ids.size(1)\n",
    "        logits = model(input_ids=input_ids,\n",
    "                       attention_mask=attention_mask,\n",
    "                       logits_to_keep=logits_to_keep + 1).logits\n",
    "    `ref_logp` (bs*L) is passed instead of reference logits because it can be obtained\n",
    "    under inference_mode() without keeping intermediate results; reference logits would\n",
    "    waste GPU memory.\n",
    "\n",
    "    Returns the per-token loss, or the pair (per_token_loss, per_token_kl) when save_kl\n",
    "    is true. No reduction is applied here; the caller reduces.\n",
    "    '''\n",
    "    assert logits.is_contiguous() and ref_logp.is_contiguous()\n",
    "    # Shift by one position: each kept logit scores the next input token.\n",
    "    shifted_logits = logits[:, :-1]\n",
    "    # The logits require grad, so the intermediate log_probs are retained for backward.\n",
    "    policy_logps = get_log_probs(shifted_logits, input_ids)\n",
    "    ref_logps = ref_logp\n",
    "    per_token_kl = torch.exp(ref_logps - policy_logps) - (ref_logps - policy_logps) - 1\n",
    "    # The old policy is refreshed at every step, so p(x) equals p_old(x) and this ratio is 1.\n",
    "    # It is kept so that the advantage term stays attached to the logits and contributes\n",
    "    # to their gradient during backpropagation.\n",
    "    ratio = torch.exp(policy_logps - policy_logps.detach())\n",
    "    per_token_loss = ratio * advantages.unsqueeze(1)\n",
    "    per_token_loss = -(per_token_loss - beta * per_token_kl)\n",
    "    if completion_mask is not None:\n",
    "        per_token_loss *= completion_mask\n",
    "        if save_kl:\n",
    "            per_token_kl *= completion_mask\n",
    "    if save_kl:\n",
    "        return per_token_loss, per_token_kl\n",
    "    return per_token_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# skip mask part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)\n",
    "#                  for bsz in [2048*(2**i) for i in range(5)]\n",
    "#                  for ns in [1,2,4]\n",
    "#                  for nw in [8, 16, 32]\n",
    "#                  ], key=['N']\n",
    "#                  )\n",
    "@triton.jit\n",
    "def _grpo_loss_fwd(LOGITS, REF_LOGP, INPUT_IDS, ADVANTAGES, MASK, BETA,\n",
    "                    LOSS, LSE, SAVE_KL,\n",
    "                    M, N, L, INPUT_IDS_START_INDEX,\n",
    "                    BLOCK_SIZE: tl.constexpr\n",
    "                    ):\n",
    "    # Forward kernel; each program handles one (sequence, position) row.\n",
    "    # For a row whose MASK entry equals 1 it stores the loss `kl * BETA - advantage` to LOSS,\n",
    "    # the log-sum-exp of the row's N logits to LSE and, when SAVE_KL is set, the kl term to LOSS + M.\n",
    "    # Rows with any other MASK value are skipped and nothing is written for them.\n",
    "    row_idx = tl.program_id(0)\n",
    "    # The last position of every sequence is not evaluated: LOGITS really holds B*(L+1) rows\n",
    "    # while only B*L programs are launched. E.g. for a 3*4*vocab_size tensor every 4th position\n",
    "    # is skipped: row_idx starts at 0, so at the first position of the second sequence row_idx\n",
    "    # is 3 while the real row index is 4. off_b records that per-sequence offset.\n",
    "    off_b = row_idx // L    \n",
    "    # NOTE(review): presumably widened so the row offset below cannot overflow 32 bits - confirm.\n",
    "    N = tl.cast(N, tl.int64)\n",
    "\n",
    "    LOGITS += N * (row_idx + off_b) # add the per-sequence offset\n",
    "    REF_LOGP += row_idx\n",
    "    # INPUT_IDS likewise carries extra prompt tokens in front of every row;\n",
    "    # e.g. with a prompt length of 64 the first row has to start at index 64.\n",
    "    INPUT_IDS += row_idx + (off_b+1) * INPUT_IDS_START_INDEX\n",
    "    LOSS += row_idx\n",
    "    LSE += row_idx\n",
    "    ADVANTAGES += off_b\n",
    "    \n",
    "    MASK += row_idx\n",
    "    not_skip = tl.load(MASK)# skip padded positions to save time\n",
    "    if not_skip == 1:       # outputs of different lengths are all padded to the longest one, which would otherwise waste a lot of compute\n",
    "        base_cols = tl.arange(0, BLOCK_SIZE)\n",
    "        # Log-sum-exp over the row in BLOCK_SIZE chunks, online-softmax style:\n",
    "        # m_i is the running maximum and l_i the running sum of exp(logit - m_i).\n",
    "        m_i = -float(\"inf\")\n",
    "        l_i = 0.0\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)\n",
    "            m_ij = tl.max(logits)\n",
    "            new_m_i = tl.maximum(m_i, m_ij)\n",
    "            l_i = l_i * tl.exp(m_i - new_m_i) + tl.sum(tl.exp(logits - new_m_i))\n",
    "            m_i = new_m_i\n",
    "        lse = tl.log(l_i) + m_i\n",
    "\n",
    "        # With lse known, just read the single logit selected by the input id (a scalar).\n",
    "        idx = tl.load(INPUT_IDS)\n",
    "        x = tl.load(LOGITS+idx).to(tl.float32)\n",
    "        advantage = tl.load(ADVANTAGES).to(tl.float32)\n",
    "        ref_logp = tl.load(REF_LOGP)\n",
    "        logp = x - lse\n",
    "        diff = ref_logp - logp\n",
    "        kl = tl.exp(diff) - diff - 1\n",
    "        # torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)\n",
    "        # is numerically equal to 1 * advantages.unsqueeze(1),\n",
    "        # so the loss simply subtracts the advantage.\n",
    "        loss = kl * BETA - advantage\n",
    "        tl.store(LOSS, loss)\n",
    "        tl.store(LSE, lse)\n",
    "        if SAVE_KL:\n",
    "            # LOSS already points at this row, so LOSS + M is the same row in the second half of the buffer.\n",
    "            tl.store(LOSS+M, kl)\n",
    "\n",
    "\n",
    "# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)\n",
    "#                  for bsz in [2048*(2**i) for i in range(5)]\n",
    "#                  for ns in [1,2,4]\n",
    "#                  for nw in [8, 16, 32]\n",
    "#                  ], key=['N']\n",
    "#                  )\n",
    "@triton.jit\n",
    "def _grpo_loss_bwd(DLOSS, DLOGITS, \n",
    "                   LOGITS, REF_LOGP, INPUT_IDS, ADVANTAGES, MASK, BETA,\n",
    "                    LSE,\n",
    "                    N, L, INPUT_IDS_START_INDEX,\n",
    "                    BLOCK_SIZE: tl.constexpr\n",
    "                    ):\n",
    "    # Backward kernel; each program writes the N logit gradients of one (sequence, position) row\n",
    "    # to DLOGITS. Rows whose MASK entry is not 1 are filled with zeros.\n",
    "    # The row addressing is the same as in the forward kernel.\n",
    "    row_idx = tl.program_id(0)\n",
    "    off_b = row_idx // L\n",
    "    N = tl.cast(N, tl.int64)\n",
    "\n",
    "    DLOSS += row_idx\n",
    "    DLOGITS += N * (row_idx + off_b)\n",
    "    LOGITS += N * (row_idx + off_b)\n",
    "    REF_LOGP += row_idx\n",
    "    INPUT_IDS += row_idx + (off_b+1) * INPUT_IDS_START_INDEX\n",
    "    LSE += row_idx\n",
    "    ADVANTAGES += off_b\n",
    "    base_cols = tl.arange(0, BLOCK_SIZE)\n",
    "\n",
    "    MASK += row_idx\n",
    "    not_skip = tl.load(MASK)\n",
    "    if not_skip == 1:\n",
    "        dloss = tl.load(DLOSS).to(tl.float32)\n",
    "        lse = tl.load(LSE)\n",
    "        idx = tl.load(INPUT_IDS)\n",
    "        x = tl.load(LOGITS+idx).to(tl.float32)\n",
    "        advantage = tl.load(ADVANTAGES).to(tl.float32)\n",
    "        ref_logp = tl.load(REF_LOGP)\n",
    "        logp = x - lse\n",
    "\n",
    "        # Gradient w.r.t. logp: BETA * d(kl)/d(logp) minus the advantage, scaled by the upstream dloss.\n",
    "        dlogp = (BETA * (-1.0 * tl.exp(ref_logp - logp) + 1) \\\n",
    "                        - advantage) \\\n",
    "                        * dloss\n",
    "\n",
    "        # Propagate dlogp to the logits: (one_hot(idx) - softmax) * dlogp.\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)\n",
    "            probs = tl.exp(logits - lse)\n",
    "            dlogits = tl.where(cols==idx, 1-probs, -probs) * dlogp\n",
    "            # DLOGITS may be the very same buffer as LOGITS (the caller passes the logits tensor\n",
    "            # itself when its inplace option is on); each chunk is loaded before it is overwritten.\n",
    "            tl.store(DLOGITS+cols, dlogits, mask=mask)\n",
    "    else:\n",
    "        dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            # Skipped row: zero the gradient, since the buffer (fresh or reused logits) holds other data.\n",
    "            tl.store(DLOGITS+cols, dlogits, mask=mask)\n",
    "\n",
    "\n",
    "class _GrpoLoss(torch.autograd.Function):\n",
    "    '''Autograd bridge between PyTorch and the Triton GRPO kernels.\n",
    "\n",
    "    Design notes:\n",
    "      * `logits` is the raw model output [B, L+1, N] rather than logits[:, :-1]. Triton needs\n",
    "        contiguous memory and logits[:, :-1].contiguous() would allocate a second full-size\n",
    "        tensor, so the kernels simply ignore the last position of every sequence.\n",
    "      * `input_ids` may keep its prompt part; the kernels skip it via the start index\n",
    "        input_ids.size(1) - L.\n",
    "      * `loss` and `lse` are stored in fp32: they have no vocab_size dimension, so the extra\n",
    "        memory is negligible while precision improves considerably.\n",
    "      * Only `logits` receives a gradient. When save_kl is set, the kl half of the output is\n",
    "        not read by backward.\n",
    "    '''\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace):\n",
    "        '''Launch the forward kernel.\n",
    "\n",
    "        Returns an fp32 tensor of shape [B, L] with the per-token loss, or [2*B, L] when\n",
    "        save_kl is true (the second half along dim 0 holds the kl term).\n",
    "        '''\n",
    "        assert logits.is_contiguous() and ref_logp.is_contiguous()\n",
    "        ctx.input_shape = logits.shape\n",
    "        B, L_ADD_1, N = ctx.input_shape\n",
    "        L = L_ADD_1 - 1\n",
    "        M = B * L  # rows actually processed: B * (L + 1 - 1)\n",
    "\n",
    "        # The kernels index these buffers with raw pointers, so a size mismatch would be an\n",
    "        # out-of-bounds read instead of a Python error. Fail loudly here.\n",
    "        assert ref_logp.numel() == M, 'ref_logp must hold B * L values'\n",
    "        assert advantages.numel() == B, 'advantages must hold one value per sequence'\n",
    "        assert input_ids.size(0) == B and input_ids.size(1) >= L, (\n",
    "            'input_ids must have B rows and at least L columns')\n",
    "        # The kernels also assume dense row-major layouts. A strided view (for example\n",
    "        # mask[:, -L:]) would silently be read at the wrong offsets. None of these tensors\n",
    "        # has a vocab_size dimension, so a copy is cheap, and .contiguous() is a no-op\n",
    "        # when the tensor is already contiguous.\n",
    "        input_ids = input_ids.contiguous()\n",
    "        advantages = advantages.contiguous()\n",
    "        # Offset of the first completion token inside each input_ids row (the prompt length).\n",
    "        input_ids_start_index = input_ids.size(1) - L\n",
    "\n",
    "        # Zero-initialised so that rows skipped by the kernel report a loss (and kl) of 0.\n",
    "        if not save_kl:\n",
    "            loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)\n",
    "        else:\n",
    "            loss = torch.zeros(B*2, L, device=logits.device, dtype=torch.float32)  # second half stores kl\n",
    "        # Per-row log-sum-exp, reused by backward. Entries of skipped rows stay uninitialised\n",
    "        # and are never read because backward skips the same rows.\n",
    "        lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)\n",
    "\n",
    "        if completion_mask is None:\n",
    "            completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)\n",
    "        else:\n",
    "            assert completion_mask.numel() == M, 'completion_mask must hold B * L values'\n",
    "            completion_mask = completion_mask.contiguous()\n",
    "        kwargs = {'BLOCK_SIZE': 8192, 'num_warps': 8, 'num_stages':1}\n",
    "        _grpo_loss_fwd[(M,)](logits, ref_logp, input_ids, advantages, completion_mask, beta,\n",
    "                            loss, lse, save_kl,\n",
    "                            M, N, L, input_ids_start_index,\n",
    "                            **kwargs,\n",
    "                            )\n",
    "        ctx.beta = beta\n",
    "        ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)\n",
    "        # NOTE(review): ref_logp is kept as a plain attribute instead of going through\n",
    "        # save_for_backward - presumably because it may be an inference-mode tensor; confirm.\n",
    "        ctx.ref_logp = ref_logp\n",
    "        ctx.inplace = inplace\n",
    "        return loss\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dloss):\n",
    "        '''Compute the gradient w.r.t. `logits`; every other input gets None.\n",
    "\n",
    "        The gradient combines the advantage (reward) part and the kl part. With ctx.inplace\n",
    "        the result is written into the storage of `logits` itself, destroying the logits.\n",
    "        '''\n",
    "        if not dloss.is_contiguous():\n",
    "            dloss = dloss.contiguous()\n",
    "\n",
    "        lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors\n",
    "        B, L_ADD_1, N = ctx.input_shape\n",
    "        L = L_ADD_1 - 1\n",
    "        M = B * L\n",
    "        input_ids_start_index = input_ids.size(1) - L\n",
    "        # Once the needed logits have been read the tensor is of no further use,\n",
    "        # so its memory can hold the gradient and no second full-size buffer is needed.\n",
    "        dlogits = logits if ctx.inplace else torch.empty_like(logits)\n",
    "        kwargs = {'BLOCK_SIZE': 8192, 'num_warps': 32, 'num_stages':4}\n",
    "        _grpo_loss_bwd[(M,)](dloss, dlogits,\n",
    "                            logits, ctx.ref_logp, input_ids, advantages, completion_mask, ctx.beta,\n",
    "                            lse,\n",
    "                            N, L, input_ids_start_index,\n",
    "                            **kwargs\n",
    "                            )\n",
    "        # The last position of every sequence took no part in the computation. Neither the\n",
    "        # torch.empty buffer nor the reused logits are zero there, so clear it explicitly.\n",
    "        dlogits[:, -1, :] = 0\n",
    "        return dlogits.view(*ctx.input_shape), None, None, None, None, None, None, None\n",
    "\n",
    "def triton_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False, inplace=True) -> 'torch.Tensor | tuple[torch.Tensor, ...]':\n",
    "    '''\n",
    "    Compute the per-token GRPO loss with fused Triton kernels.\n",
    "    Author's note: saves memory (no additional usage) and is fast (6X on A800).\n",
    "\n",
    "    Args:\n",
    "        logits: Tensor, [B, L+1, vocab_size], the original contiguous output of the model,\n",
    "            not logits[:, :-1]. With inplace=True the backward pass overwrites this tensor\n",
    "            with its own gradient, so it must not be reused afterwards.\n",
    "        ref_logp: Tensor, [B, L], contiguous per-token log-probs of the reference model for\n",
    "            the completion tokens (log-probs, not reference logits)\n",
    "        input_ids: Tensor, [B, K+L], the prompt_completion_ids, i.e. prompt ids followed by output ids\n",
    "        advantages: Tensor, [B], the advantage of each sequence\n",
    "        beta: float, the weight of the kl term\n",
    "        completion_mask: Tensor, [B, L], loss mask; positions whose value is 1 are computed,\n",
    "            all others are skipped. None means every position is computed.\n",
    "        save_kl: bool, if true the per-token kl is returned as well\n",
    "        inplace: bool, if true, backward stores the gradient of logits in the memory of logits\n",
    "            itself, which saves memory\n",
    "\n",
    "    Returns:\n",
    "        loss: Tensor, [B, L], fp32, the grpo loss containing the advantage part and the kl part.\n",
    "        When save_kl is true a tuple (loss, kl) of two [B, L] tensors is returned instead.\n",
    "\n",
    "    NOTE: logits is computed by these steps\n",
    "        logits_to_keep = completion_ids.size(1)\n",
    "\n",
    "        def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):\n",
    "            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded\n",
    "            logits = model(\n",
    "                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1\n",
    "            ).logits\n",
    "            return logits\n",
    "            \n",
    "        logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)\n",
    "    '''\n",
    "    out = _GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace)\n",
    "    if not save_kl:\n",
    "        return out\n",
    "    # The autograd function packs [loss; kl] along dim 0, so split it back into two [B, L] halves.\n",
    "    return out.chunk(2, dim=0)\n",
    "    \n",
    "def get_random_ref_log_probs(logits, input_ids):\n",
    "    '''Build stand-in reference log-probs for testing.\n",
    "\n",
    "    For every batch row, random values shaped like that row of logits[:, :-1] are\n",
    "    log-softmaxed and the entries selected by the trailing input_ids are gathered.\n",
    "    Everything runs under torch.inference_mode().\n",
    "    '''\n",
    "    with torch.inference_mode():\n",
    "        trimmed = logits[:, :-1]\n",
    "        target_ids = input_ids[:, -trimmed.size(1):]\n",
    "        rows = []\n",
    "        for row_like, row_ids in zip(trimmed, target_ids):\n",
    "            random_logp = torch.randn_like(row_like).log_softmax(dim=-1)\n",
    "            picked = torch.gather(random_logp, dim=1, index=row_ids.unsqueeze(1))\n",
    "            rows.append(picked.squeeze(1))\n",
    "        torch.cuda.empty_cache()\n",
    "        return torch.stack(rows)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Accuracy test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so that the accuracy numbers below are reproducible across runs.\n",
    "torch.manual_seed(0)\n",
    "dtype = torch.bfloat16\n",
    "device = 'cuda'\n",
    "bs, seq_len, vocab_size = 8, 1024, 150000\n",
    "# The last position holds the logits of the eos token; it is dropped during the computation.\n",
    "logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits.requires_grad_(True)\n",
    "# Independent copy for the Triton path.\n",
    "copy_logits = deepcopy(logits)\n",
    "advantages = torch.randn(bs, device=device, dtype=torch.float32)\n",
    "# 64 is an arbitrary prompt length; the remaining seq_len tokens are the output.\n",
    "input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device)\n",
    "ref_logp = get_random_ref_log_probs(logits, input_ids)\n",
    "beta = 0.04\n",
    "completion_mask = torch.ones(bs, seq_len, dtype=torch.int32, device=device)\n",
    "completion_mask[::2, seq_len//2:] = 0  # pretend the second half of every other sequence is padding\n",
    "save_kl = True\n",
    "\n",
    "# fp32 copies serve as the gold reference for the comparisons.\n",
    "gold_logits = logits.detach().clone().float()\n",
    "gold_logits.requires_grad_(True)\n",
    "gold_ref_logp= deepcopy(ref_logp).float()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0,) torch.Size([8, 1024])\n"
     ]
    }
   ],
   "source": [
    "# Smoke test for backward with the gradient produced by sum().backward().\n",
    "# inplace=False keeps copy_logits intact: with the default inplace=True, backward would overwrite\n",
    "# copy_logits with its own gradient and invalidate the accuracy comparison in the next cell.\n",
    "y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl=False, inplace=False)\n",
    "y2.sum().backward()\n",
    "# Drop the gradient so it does not accumulate into the backward pass of the next cell.\n",
    "copy_logits.grad = None\n",
    "# (y2.sum(-1) / completion_mask.sum(-1)).mean().backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================== KL:\n",
      "tensor(3.7727, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.0271, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "tensor(0.0003, device='cuda:0', grad_fn=<MaxBackward1>) tensor(2.6195e-07, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "================================================== Loss:\n",
      "tensor(0.1397, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.0011, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "tensor(1.2875e-05, device='cuda:0', grad_fn=<MaxBackward1>) tensor(1.1167e-08, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "================================================== Grad:\n",
      "tensor(0.1768, device='cuda:0') tensor(5.9009e-08, device='cuda:0')\n",
      "tensor(0.0132, device='cuda:0') tensor(8.0266e-09, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "torch.cuda.empty_cache()\n",
    "# Start from clean gradients so that re-running this cell never accumulates onto earlier results.\n",
    "logits.grad = None\n",
    "copy_logits.grad = None\n",
    "gold_logits.grad = None\n",
    "y1 = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)           # torch bf16\n",
    "# inplace=False: backward must not overwrite copy_logits, otherwise this cell cannot be re-run\n",
    "# and later cells would operate on gradients instead of logits.\n",
    "y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl=save_kl, inplace=False)     # triton\n",
    "y3 = torch_grpo_loss(gold_logits, gold_ref_logp, input_ids, advantages, beta, completion_mask, save_kl) # torch fp32\n",
    "if save_kl:\n",
    "    y1, kl1 = y1\n",
    "    y2, kl2 = y2\n",
    "    y3, kl3 = y3\n",
    "    print('='*50 + ' KL:')\n",
    "    print((kl1-kl3).abs().max(), (kl1-kl3).abs().mean())  # kl, torch bf16 vs torch fp32,\n",
    "    print((kl2-kl3).abs().max(), (kl2-kl3).abs().mean())  # kl, triton vs torch fp32,\n",
    "# Use the same upstream gradient for all three implementations.\n",
    "dy = torch.randn_like(y1)\n",
    "y1.backward(dy)\n",
    "y2.backward(dy)\n",
    "y3.backward(dy.float())\n",
    "print('='*50 + ' Loss:')\n",
    "print((y1-y3).abs().max(), (y1-y3).abs().mean())  # fwd, torch bf16 vs torch fp32\n",
    "print((y2-y3).abs().max(), (y2-y3).abs().mean())  # fwd, triton vs torch fp32\n",
    "print('='*50 + ' Grad:')\n",
    "print((logits.grad - gold_logits.grad).abs().max(), (logits.grad - gold_logits.grad).abs().mean()) # bwd, torch bf16 vs torch fp32\n",
    "print((copy_logits.grad - gold_logits.grad).abs().max(), (copy_logits.grad - gold_logits.grad).abs().mean()) # bwd, triton vs torch fp32\n",
    "# Author's observation over several runs: the Triton results are more accurate, with smaller errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.588759660720825\n",
      "Triton autotuning for function _grpo_loss_fwd finished after 12.51s; best config selected: BLOCK_SIZE: 8192, num_warps: 8, num_ctas: 1, num_stages: 1, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;\n",
      "0.514387845993042\n"
     ]
    }
   ],
   "source": [
    "# Forward-pass timing via triton.testing.do_bench: PyTorch reference first, Triton kernel second.\n",
    "# NOTE(review): the PyTorch call gets no completion_mask while the Triton call does, so the\n",
    "# Triton kernel skips the masked rows - confirm this asymmetry is intended.\n",
    "print(triton.testing.do_bench(lambda:torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)))\n",
    "print(triton.testing.do_bench(lambda:triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.383981704711914\n",
      "Triton autotuning for function _grpo_loss_bwd finished after 13.34s; best config selected: BLOCK_SIZE: 8192, num_warps: 32, num_ctas: 1, num_stages: 4, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;\n",
      "1.1460970640182495\n"
     ]
    }
   ],
   "source": [
    "# Backward-pass timing. Regenerate the data first, then run this cell directly.\n",
    "y1 = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)\n",
    "# inplace=False so that the repeated backward calls below do not overwrite copy_logits.\n",
    "y2 = triton_grpo_loss(copy_logits, ref_logp, input_ids, advantages, beta, completion_mask, inplace=False)\n",
    "dy = torch.randn_like(y1)\n",
    "# retain_graph=True keeps the graph alive so backward can be called repeatedly by do_bench.\n",
    "print(triton.testing.do_bench(lambda:y1.backward(dy, retain_graph=True), grad_to_none=[logits]))\n",
    "print(triton.testing.do_bench(lambda:y2.backward(dy, retain_graph=True), grad_to_none=[copy_logits]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPU memory usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logits显存占用： 4.579871892929077 G\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 208/500 [00:05<00:07, 38.79it/s]"
     ]
    }
   ],
   "source": [
    "# GPU memory footprint of the PyTorch reference (forward + backward in a loop).\n",
    "# Restart/refresh the kernel before running this cell.\n",
    "dtype = torch.bfloat16\n",
    "device = 'cuda'\n",
    "bs, seq_len, vocab_size = 8, 2048, 150000\n",
    "logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits.requires_grad_(True)\n",
    "advantages = torch.randn(bs, device=device, dtype=torch.float32)\n",
    "input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device)\n",
    "ref_logp = get_random_ref_log_probs(logits, input_ids)\n",
    "beta = 0.04\n",
    "iters = 500\n",
    "dy = torch.randn(bs, seq_len, device=device, dtype=dtype)\n",
    "# Bytes per element: 4 for float32, otherwise 2.\n",
    "factor = 4 if dtype == torch.float32 else 2\n",
    "print('logits显存占用：',(bs * (seq_len+1) * vocab_size) / (1024)**3 * factor,\"G\")\n",
    "time.sleep(3) # pause so the initial GPU memory can be observed, e.g. with nvitop\n",
    "for i in tqdm(range(iters)):\n",
    "    y = torch_grpo_loss(logits, ref_logp, input_ids, advantages, beta)\n",
    "    y.backward(dy)\n",
    "    logits.grad = None\n",
    "# Observed by the author: 5.7G -> 24.6 G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logits显存占用： 4.579871892929077 G\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  2.61it/s]"
     ]
    }
   ],
   "source": [
    "# GPU memory footprint of the Triton implementation with inplace=True.\n",
    "# Restart/refresh the kernel before running this cell.\n",
    "dtype = torch.bfloat16\n",
    "device = 'cuda'\n",
    "bs, seq_len, vocab_size = 8, 2048, 150000\n",
    "logits = torch.randn(bs, seq_len + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits.requires_grad_(True)\n",
    "advantages = torch.randn(bs, device=device, dtype=torch.float32)\n",
    "input_ids = torch.randint(0, vocab_size-1, (bs, seq_len + 64), device=device)\n",
    "ref_logp = get_random_ref_log_probs(logits, input_ids)\n",
    "beta = 0.04\n",
    "completion_mask = torch.ones(bs, seq_len, dtype=torch.int32, device=device)\n",
    "completion_mask[::2, seq_len//2:] = 0 \n",
    "# NOTE(review): with inplace=True the backward pass writes the gradient into the logits buffer,\n",
    "# so with iters > 1 later iterations would no longer see the original logits - confirm before raising iters.\n",
    "iters = 1\n",
    "dy = torch.randn(bs, seq_len, device=device, dtype=dtype)\n",
    "# Bytes per element: 4 for float32, otherwise 2.\n",
    "factor = 4 if dtype == torch.float32 else 2\n",
    "print('logits显存占用：',(bs * (seq_len+1) * vocab_size) / (1024)**3 * factor,\"G\")\n",
    "time.sleep(3) # pause so the initial GPU memory can be observed, e.g. with nvitop\n",
    "pbar = tqdm(total=iters)\n",
    "for i in range(iters):\n",
    "    y = triton_grpo_loss(logits, ref_logp, input_ids, advantages, beta, completion_mask, inplace=True)\n",
    "    y.backward(dy)\n",
    "    # logits.grad = None\n",
    "    pbar.update(1)\n",
    "# Observed by the author: 5.7G -> 5.7G, essentially no extra memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate ref_logp and copy it outside the helper's torch.inference_mode() block.\n",
    "# NOTE(review): presumably done to obtain a regular (non-inference) tensor - confirm.\n",
    "ref_logp = get_random_ref_log_probs(logits, input_ids).clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140140152356864"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Address of the gradient buffer; compare it with logits.data_ptr() in the next cell.\n",
    "logits.grad.data_ptr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140140152356864"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Address of the logits buffer; equal to the gradient's address when backward ran with inplace=True.\n",
    "logits.data_ptr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the raw gradient values stored for logits.\n",
    "logits.grad.data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
